import math
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.linalg
from linearmodels.iv import IV2SLS
from linearmodels.iv import IVLIML


# Numeric codes that clean_numeric() masks to NaN. The set is applied by value to
# whatever column is passed in, independent of the variable's answer scale.
# NOTE(review): because 1- to 4-digit codes are mixed, a legitimate answer equal to
# e.g. 7, 8, 9 or 77 is masked too -- confirm this is intended for every variable.
MISSING_CODES = {7, 8, 9, 77, 88, 99, 777, 888, 999, 7777, 8888, 9999}

# Default instrument set (base instruments only; the fit functions add the
# interactions with edu_high themselves).
# These are shift-share terms: region baseline broadband access (2016) × country-year
# coverage shocks from Eurostat (isoc_cbs / isoc_cbt).
INSTRUMENT_BASE_COLS = [
    "z_ss_cov30",
    "z_ss_cov100",
    "z_ss_fttp",
    "z_ss_gbps1",
    "z_ss_vhcn_fx",
    "z_ss_nga",
    "z_ss_docsis31",
]

# Absolute tolerance used by select_instruments_full_rank(), both to drop
# (near-)constant instrument columns and to count non-negligible diagonal entries
# of the pivoted QR factor.
RANK_TOL = 1e-10


def clean_numeric(series: pd.Series) -> pd.Series:
    s = pd.to_numeric(series, errors="coerce")
    s = s.mask(s.isin(MISSING_CODES))
    return s


def recode_yes_no(series: pd.Series) -> pd.Series:
    """Turn a 1/2 coded item into a 1.0/0.0 indicator.

    The column is first passed through ``clean_numeric``. Code 1 then maps to 1.0
    and code 2 to 0.0; every other value (masked codes, unparseable entries, any
    other number) has no entry in the mapping and therefore ends up as NaN.
    """
    yes_no_codes = {1: 1.0, 2: 0.0}
    cleaned = clean_numeric(series)
    return cleaned.map(yes_no_codes)


def absorb_two_way(
    df: pd.DataFrame,
    cols: list[str],
    fe_a: str,
    fe_b: str,
    max_iter: int = 30,
    tol: float = 1e-10,
    weights: pd.Series | None = None,
) -> pd.DataFrame:
    x = df[cols].to_numpy(dtype=float, copy=True)
    a_codes = pd.Categorical(df[fe_a]).codes.astype(np.int32)
    b_codes = pd.Categorical(df[fe_b]).codes.astype(np.int32)
    a_n = int(a_codes.max() + 1)
    b_n = int(b_codes.max() + 1)

    if weights is None:
        w = None
        a_counts = np.bincount(a_codes, minlength=a_n).astype(float)
        b_counts = np.bincount(b_codes, minlength=b_n).astype(float)
    else:
        w = pd.to_numeric(weights, errors="coerce").to_numpy(dtype=float, copy=False)
        w = np.where(np.isfinite(w) & (w > 0), w, np.nan)
        # Replace nonpositive/missing weights with 1.0 for stability in absorption;
        # rows with missing weights should be dropped by caller if needed.
        w = np.where(np.isfinite(w), w, 1.0)
        a_counts = np.bincount(a_codes, weights=w, minlength=a_n).astype(float)
        b_counts = np.bincount(b_codes, weights=w, minlength=b_n).astype(float)

    for _ in range(max_iter):
        prev = x.copy()
        for j in range(x.shape[1]):
            sums = np.bincount(a_codes, weights=(x[:, j] if w is None else (w * x[:, j])), minlength=a_n)
            means = sums / a_counts
            x[:, j] = x[:, j] - means[a_codes]
        for j in range(x.shape[1]):
            sums = np.bincount(b_codes, weights=(x[:, j] if w is None else (w * x[:, j])), minlength=b_n)
            means = sums / b_counts
            x[:, j] = x[:, j] - means[b_codes]
        delta = float(np.nanmax(np.abs(x - prev)))
        if delta < tol:
            break

    return pd.DataFrame(x, columns=cols, index=df.index)


@dataclass(frozen=True)
class OutcomeSpec:
    """Immutable description of one outcome variable to estimate.

    Instances are created in ``main`` and drive the main 2SLS and LIML tables.
    """

    # Column name of the outcome in the estimation DataFrame.
    name: str
    # Human-readable label written to the result tables.
    label: str
    # Type tag ("binary" / "continuous" in ``main``); not read by any code
    # visible in this module.
    kind: str

# Variables whose legitimate answers collide with the short codes in MISSING_CODES
# (e.g. income deciles 7-9, 7-9 years of education, scores 7-9 on 0-10 scales,
# ages 77/88/99, small minute counts). For these only the wider codes are masked.
# NOTE(review): the code lists follow the published ESS missing-value convention
# (codes one digit wider than the answer scale); confirm against the ESS codebooks
# for rounds 8-11 before relying on the estimates.
_WIDE_MISSING_CODES = {
    "agea": (777, 888, 999),
    "eduyrs": (77, 88, 99),
    "hinctnta": (77, 88, 99),
    "nwspol": (7777, 8888, 9999),
    "ppltrst": (77, 88, 99),
    "stfdem": (77, 88, 99),
    "trstplt": (77, 88, 99),
    "trstprl": (77, 88, 99),
    "trstprt": (77, 88, 99),
}


def _clean_survey_column(df: pd.DataFrame, col: str) -> pd.Series:
    """Clean ``df[col]`` with variable-specific missing codes where defined."""
    codes = _WIDE_MISSING_CODES.get(col)
    if codes is None:
        return clean_numeric(df[col])
    numeric = pd.to_numeric(df[col], errors="coerce")
    return numeric.mask(numeric.isin(codes))


def build_dataset(root: Path) -> pd.DataFrame:
    """Load the survey extract, attach the shift-share instruments and recode variables.

    Reads ``outputs/ess_r8_r11_min.parquet`` and
    ``data_external/rollout_shiftshare_region_year.csv`` below *root*, keeps
    respondents with ``regunit`` in {1, 2} and a region other than "99999",
    left-joins the instrument columns on (region, survey_year), and derives the
    outcome, mechanism, placebo and control columns used by the fit functions.

    Args:
        root: Project root directory containing ``outputs`` and ``data_external``.

    Returns:
        The estimation pool: rows with at least one non-missing instrument and
        non-missing ``edu_high`` and ``netusoft``.

    Raises:
        pandas.errors.MergeError: If the shift-share file has more than one row
            for a (region, survey_year) key.
    """
    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    ss = pd.read_csv(root / "data_external" / "rollout_shiftshare_region_year.csv")

    ess = ess[ess["regunit"].isin([1, 2])].copy()
    ess["region"] = ess["region"].astype(str).str.upper().str.strip()
    ess = ess[ess["region"] != "99999"].copy()
    ess["survey_year"] = pd.to_numeric(ess["survey_year"], errors="coerce").astype("Int64")

    ss["region"] = ss["region"].astype(str).str.upper().str.strip()
    ss["survey_year"] = pd.to_numeric(ss["survey_year"], errors="coerce").astype("Int64")
    for col in INSTRUMENT_BASE_COLS:
        # Instruments absent from the file are kept as all-NaN columns so the
        # column selection below never fails.
        ss[col] = pd.to_numeric(ss[col], errors="coerce") if col in ss.columns else np.nan

    # validate= guards against duplicate keys in the shift-share file, which would
    # otherwise silently duplicate respondents.
    df = ess.merge(
        ss[["region", "survey_year", *INSTRUMENT_BASE_COLS]],
        on=["region", "survey_year"],
        how="left",
        validate="many_to_one",
    )

    # Endogenous
    df["netusoft"] = clean_numeric(df["netusoft"])

    # Outcomes
    df["y_vote"] = recode_yes_no(df["vote"])
    df["y_sgnptit"] = recode_yes_no(df["sgnptit"])
    df["y_pbldmn"] = recode_yes_no(df["pbldmn"])
    df["y_bctprd"] = recode_yes_no(df["bctprd"])
    df["y_contplt"] = recode_yes_no(df["contplt"])
    df["y_badge"] = recode_yes_no(df["badge"])
    df["y_pstplonl"] = recode_yes_no(df["pstplonl"]) if "pstplonl" in df.columns else np.nan
    part_cols = ["y_sgnptit", "y_pbldmn", "y_bctprd", "y_contplt", "y_badge", "y_pstplonl"]
    # Row mean over the items that were answered; a row with no answered item
    # yields NaN on its own, so no extra masking step is needed.
    df["y_part_index"] = df[part_cols].mean(axis=1, skipna=True)

    # Mechanism
    df["nwspol"] = _clean_survey_column(df, "nwspol")
    df["log_nwspol"] = np.log1p(df["nwspol"].clip(lower=0))

    # Additional mechanism / attitude measures (available in ESS9-11)
    for col in ["polintr", "ppltrst", "stfdem", "trstplt", "trstprl", "trstprt", "actrolga", "prtdgcl", "inprdsc"]:
        if col in df.columns:
            df[col] = _clean_survey_column(df, col)

    # Placebos
    df["gndr_num"] = clean_numeric(df["gndr"])
    df["gndr"] = df["gndr_num"]
    df["brncntr_bin"] = recode_yes_no(df["brncntr"]) if "brncntr" in df.columns else np.nan

    # Controls
    df["agea"] = _clean_survey_column(df, "agea")
    df["hinctnta"] = _clean_survey_column(df, "hinctnta")
    df["domicil"] = clean_numeric(df["domicil"]) if "domicil" in df.columns else np.nan
    df["eduyrs"] = _clean_survey_column(df, "eduyrs")

    # 1.0 for more than 12 years of education, 0.0 otherwise, NaN when unknown.
    df["edu_high"] = (df["eduyrs"] > 12).astype(float).where(df["eduyrs"].notna())

    # Country-by-year cell used as the second fixed effect.
    df["ct"] = df["cntry"].astype(str) + "_" + df["survey_year"].astype("Int64").astype(str)

    # Keep usable estimation sample
    has_instrument = df[INSTRUMENT_BASE_COLS].notna().any(axis=1)
    has_treatment = df["edu_high"].notna() & df["netusoft"].notna()
    return df[has_instrument & has_treatment].copy()


def fit_iv_interaction_absorb(df: pd.DataFrame, y: str, controls: list[str]):
    """Fit the 2SLS interaction model on two-way demeaned data.

    ``netusoft`` and ``netusoft * edu_high`` are treated as endogenous and are
    instrumented by the base instruments and their ``edu_high`` interactions;
    *controls* and ``edu_high`` enter as exogenous regressors. Region and
    country-year (``ct``) fixed effects are removed with ``absorb_two_way`` before
    estimation, and redundant instruments are dropped by
    ``select_instruments_full_rank``.

    Args:
        df: Estimation pool containing *y*, ``netusoft``, ``edu_high``, ``ct``,
            ``region``, *controls* and the instrument columns.
        y: Name of the outcome column.
        controls: Names of the exogenous control columns.

    Returns:
        The results object returned by ``IV2SLS(...).fit`` with covariance
        clustered on ``region`` (not the ``IV2SLS`` model itself).

    Raises:
        ValueError: If no instrument column holds any data, or no row is complete
            in all model columns.
    """
    # build_dataset() keeps an all-NaN column for every instrument missing from
    # the shift-share file; using such a column would make dropna() below discard
    # every row, so only instruments that actually carry data are used.
    instr_base = [c for c in INSTRUMENT_BASE_COLS if c in df.columns and df[c].notna().any()]
    if not instr_base:
        raise ValueError("No instrument columns found in dataset")

    controls = list(controls)
    instr_x_edu = [f"{z}_x_edu" for z in instr_base]

    d = df[[y, "netusoft", "edu_high", *instr_base, "ct", "region", *controls]].dropna().copy()
    if d.empty:
        raise ValueError(f"No rows after dropna for outcome={y}")

    d["netusoft_x_edu"] = d["netusoft"] * d["edu_high"]
    for z, z_x in zip(instr_base, instr_x_edu):
        d[z_x] = d[z] * d["edu_high"]

    model_cols = [y, "netusoft", "netusoft_x_edu", *instr_base, *instr_x_edu, *controls, "edu_high"]
    r = absorb_two_way(d, model_cols, fe_a="region", fe_b="ct")

    exog = r[[*controls, "edu_high"]]
    endog = r[["netusoft", "netusoft_x_edu"]]
    instr = select_instruments_full_rank(
        exog=exog,
        instruments=r[[*instr_base, *instr_x_edu]],
        min_cols=endog.shape[1],
    )

    mod = IV2SLS(r[y], exog=exog, endog=endog, instruments=instr)
    return mod.fit(cov_type="clustered", clusters=d["region"])


def summarize_result(res, y: str, label: str) -> dict:
    """Flatten one fitted IV result into a single table row.

    Args:
        res: Fitted result exposing ``params``, ``std_errors``, ``nobs``,
            ``first_stage.diagnostics`` and, optionally, ``anderson_rubin.pval``.
        y: Outcome column name, stored under ``"outcome"``.
        label: Human-readable outcome label, stored under ``"label"``.

    Returns:
        A dict with the ``netusoft`` coefficient (low-education effect), the
        ``netusoft_x_edu`` coefficient (high-minus-low difference), their sum
        (high-education effect), the two standard errors, the first-stage
        ``f.stat`` of both endogenous regressors and the Anderson-Rubin p-value.
        Anything that cannot be looked up is reported as NaN.

    NOTE(review): ``res.anderson_rubin`` in linearmodels appears to be the
    overidentification variant of the test -- confirm that is what the
    ``AR_pvalue`` column is meant to report.
    """
    base_name = "netusoft"
    inter_name = "netusoft_x_edu"

    coef_low = float(res.params.get(base_name, np.nan))
    coef_delta = float(res.params.get(inter_name, np.nan))
    se_low = float(res.std_errors.get(base_name, np.nan))
    se_delta = float(res.std_errors.get(inter_name, np.nan))

    # First-stage rows are looked up by name; when a name is absent the first /
    # last row of the diagnostics table is used instead.
    fs_table = res.first_stage.diagnostics
    if base_name in fs_table.index:
        fs_base = fs_table.loc[base_name]
    else:
        fs_base = fs_table.iloc[0]
    if inter_name in fs_table.index:
        fs_inter = fs_table.loc[inter_name]
    else:
        fs_inter = fs_table.iloc[-1]

    # Best effort: any failure while obtaining the test leaves the p-value as NaN.
    try:
        ar_pvalue = float(res.anderson_rubin.pval)
    except Exception:
        ar_pvalue = np.nan

    summary = {"outcome": y, "label": label, "nobs": int(res.nobs)}
    summary["coef_lowEdu"] = coef_low
    summary["se_lowEdu"] = se_low
    summary["coef_highEdu"] = coef_low + coef_delta
    summary["coef_deltaHighMinusLow"] = coef_delta
    summary["se_deltaHighMinusLow"] = se_delta
    summary["fsF_netusoft"] = float(fs_base.get("f.stat", np.nan))
    summary["fsF_netusoft_x_edu"] = float(fs_inter.get("f.stat", np.nan))
    summary["AR_pvalue"] = ar_pvalue
    return summary


def to_markdown_table(df: pd.DataFrame) -> str:
    """Render a result table as a Markdown table.

    Coefficients and standard errors are shown with four decimals, first-stage
    F statistics with two, and the AR p-value with four significant digits;
    missing values in those columns become empty cells. ``nobs`` is passed
    through unformatted.

    Args:
        df: Rows as produced by ``summarize_result``. Columns that are absent
            (for example when the table has no rows at all) are treated as
            missing instead of raising ``KeyError``.

    Returns:
        The Markdown text produced by ``DataFrame.to_markdown(index=False)``.
    """
    cols = [
        "label",
        "nobs",
        "coef_lowEdu",
        "se_lowEdu",
        "coef_highEdu",
        "coef_deltaHighMinusLow",
        "se_deltaHighMinusLow",
        "fsF_netusoft",
        "fsF_netusoft_x_edu",
        "AR_pvalue",
    ]
    formats = {
        "coef_lowEdu": ".4f",
        "se_lowEdu": ".4f",
        "coef_highEdu": ".4f",
        "coef_deltaHighMinusLow": ".4f",
        "se_deltaHighMinusLow": ".4f",
        "fsF_netusoft": ".2f",
        "fsF_netusoft_x_edu": ".2f",
        "AR_pvalue": ".4g",
    }

    def _formatter(spec: str):
        return lambda value: format(value, spec) if pd.notna(value) else ""

    # reindex() selects the columns like df[cols] but tolerates an empty or
    # partial frame, which df[cols] rejects with KeyError.
    d = df.reindex(columns=cols)
    for col, spec in formats.items():
        d[col] = d[col].map(_formatter(spec))
    return d.to_markdown(index=False)


def to_latex_tabular(df: pd.DataFrame) -> str:
    """Render a result table as a booktabs-style LaTeX ``tabular`` environment.

    ``nobs`` is printed without decimals, coefficients and standard errors with
    four decimals, first-stage F statistics with two, and the AR p-value with
    four significant digits; missing values become empty cells. Labels are
    written verbatim (no LaTeX escaping is applied).

    Args:
        df: Rows as produced by ``summarize_result``. Columns that are absent
            (for example when the table has no rows at all) are treated as
            missing instead of raising ``KeyError``.

    Returns:
        The ``tabular`` source, one table row per input row, without a trailing
        newline.
    """
    cols = [
        "label",
        "nobs",
        "coef_lowEdu",
        "se_lowEdu",
        "coef_highEdu",
        "coef_deltaHighMinusLow",
        "se_deltaHighMinusLow",
        "fsF_netusoft",
        "fsF_netusoft_x_edu",
        "AR_pvalue",
    ]
    header = [
        "Outcome",
        "N",
        "b(LowEdu)",
        "se(LowEdu)",
        "b(HighEdu)",
        "Delta(High-Low)",
        "se(Delta)",
        "F(net)",
        "F(net$\\times$edu)",
        "AR p",
    ]
    formats = {
        "nobs": ".0f",
        "coef_lowEdu": ".4f",
        "se_lowEdu": ".4f",
        "coef_highEdu": ".4f",
        "coef_deltaHighMinusLow": ".4f",
        "se_deltaHighMinusLow": ".4f",
        "fsF_netusoft": ".2f",
        "fsF_netusoft_x_edu": ".2f",
        "AR_pvalue": ".4g",
    }

    def _formatter(spec: str):
        return lambda value: format(value, spec) if pd.notna(value) else ""

    # reindex() selects the columns like df[cols] but tolerates an empty or
    # partial frame, which df[cols] rejects with KeyError.
    d = df.reindex(columns=cols)
    for col, spec in formats.items():
        d[col] = d[col].map(_formatter(spec))

    row_end = " \\\\"
    body = [
        " & ".join(str(cell) for cell in cells) + row_end
        for cells in d.itertuples(index=False, name=None)
    ]
    lines = [
        "\\begin{tabular}{lrrrrrrrrr}",
        "\\toprule",
        " & ".join(header) + row_end,
        "\\midrule",
        *body,
        "\\bottomrule",
        "\\end{tabular}",
    ]
    return "\n".join(lines)


def fit_liml_interaction_absorb(df: pd.DataFrame, y: str, controls: list[str]):
    """Fit the LIML version of the interaction model on two-way demeaned data.

    The specification mirrors the 2SLS fit: ``netusoft`` and
    ``netusoft * edu_high`` are endogenous, instrumented by the base instruments
    and their ``edu_high`` interactions; *controls* and ``edu_high`` are
    exogenous. Region and country-year (``ct``) fixed effects are removed with
    ``absorb_two_way`` and redundant instruments are dropped by
    ``select_instruments_full_rank``.

    Args:
        df: Estimation pool containing *y*, ``netusoft``, ``edu_high``, ``ct``,
            ``region``, *controls* and the instrument columns.
        y: Name of the outcome column.
        controls: Names of the exogenous control columns.

    Returns:
        The results object returned by ``IVLIML(...).fit`` with covariance
        clustered on ``region``.

    Raises:
        ValueError: If no instrument column holds any data, or no row is complete
            in all model columns.
    """
    # build_dataset() keeps an all-NaN column for every instrument missing from
    # the shift-share file; using such a column would make dropna() below discard
    # every row, so only instruments that actually carry data are used.
    instr_base = [c for c in INSTRUMENT_BASE_COLS if c in df.columns and df[c].notna().any()]
    if not instr_base:
        raise ValueError("No instrument columns found in dataset")

    controls = list(controls)
    instr_x_edu = [f"{z}_x_edu" for z in instr_base]

    d = df[[y, "netusoft", "edu_high", *instr_base, "ct", "region", *controls]].dropna().copy()
    if d.empty:
        raise ValueError(f"No rows after dropna for outcome={y}")

    d["netusoft_x_edu"] = d["netusoft"] * d["edu_high"]
    for z, z_x in zip(instr_base, instr_x_edu):
        d[z_x] = d[z] * d["edu_high"]

    model_cols = [y, "netusoft", "netusoft_x_edu", *instr_base, *instr_x_edu, *controls, "edu_high"]
    r = absorb_two_way(d, model_cols, fe_a="region", fe_b="ct")

    exog = r[[*controls, "edu_high"]]
    endog = r[["netusoft", "netusoft_x_edu"]]
    instr = select_instruments_full_rank(
        exog=exog,
        instruments=r[[*instr_base, *instr_x_edu]],
        min_cols=endog.shape[1],
    )

    mod = IVLIML(r[y], exog=exog, endog=endog, instruments=instr)
    return mod.fit(cov_type="clustered", clusters=d["region"])


def select_instruments_full_rank(exog: pd.DataFrame, instruments: pd.DataFrame, min_cols: int) -> pd.DataFrame:
    """
    Linearmodels requires that [exog instruments] has full column rank.
    In some subsamples or robustness specifications, some instrument components can become
    (near-)collinear after FE absorption. This helper drops redundant columns deterministically.
    """
    z = instruments.copy()
    # Drop near-constant or non-finite columns (after absorption these can happen).
    z = z.loc[:, np.isfinite(z.to_numpy(dtype=float)).all(axis=0)].copy()
    if z.shape[1] == 0:
        raise ValueError("No finite instrument columns available")

    z_std = np.nanstd(z.to_numpy(dtype=float), axis=0)
    keep = z_std > RANK_TOL
    z = z.loc[:, keep].copy()
    if z.shape[1] == 0:
        raise ValueError("All instruments are (near-)constant after absorption")

    x = exog.to_numpy(dtype=float, copy=False)
    z_mat = z.to_numpy(dtype=float, copy=False)

    # Residualize instruments on exog to ensure added columns increase rank of [exog instruments].
    beta, *_ = np.linalg.lstsq(x, z_mat, rcond=None)
    z_res = z_mat - x @ beta

    # Pivoted QR to pick a well-conditioned full-rank subset.
    _, r, piv = scipy.linalg.qr(z_res, mode="economic", pivoting=True)
    rank = int(np.sum(np.abs(np.diag(r)) > RANK_TOL))
    if rank == 0:
        raise ValueError("Instrument residuals have rank 0 after partialling out exog")

    selected = [z.columns[int(i)] for i in piv[:rank]]
    if len(selected) < min_cols:
        # Fallback: keep at least min_cols columns in original order.
        selected = list(z.columns[:min_cols])

    return z[selected]


# Column layout of a result row, in the key order produced by summarize_result().
_RESULT_COLUMNS = [
    "outcome",
    "label",
    "nobs",
    "coef_lowEdu",
    "se_lowEdu",
    "coef_highEdu",
    "coef_deltaHighMinusLow",
    "se_deltaHighMinusLow",
    "fsF_netusoft",
    "fsF_netusoft_x_edu",
    "AR_pvalue",
]


def _failed_row(outcome: str, label: str, error: Exception) -> dict:
    """Build the placeholder row recorded when a fit raises."""
    row = {"outcome": outcome, "label": label, "nobs": 0}
    row.update(dict.fromkeys(_RESULT_COLUMNS[3:], np.nan))
    row["error"] = str(error)
    return row


def _has_data(df: pd.DataFrame, column: str) -> bool:
    """Return True if *column* exists in *df* and holds at least one value."""
    return column in df.columns and bool(df[column].notna().any())


def _rows_to_frame(rows: list[dict]) -> pd.DataFrame:
    """Turn result rows into a table that has the expected columns even when empty."""
    return pd.DataFrame(rows) if rows else pd.DataFrame(columns=_RESULT_COLUMNS)


def _fit_outcome_specs(df: pd.DataFrame, outcomes: list, controls: list[str], fit, tag: str) -> list[dict]:
    """Fit every OutcomeSpec with *fit*; a failure is recorded as a placeholder row."""
    rows = []
    for spec in outcomes:
        print(f"{tag}: {spec.name}", flush=True)
        try:
            res = fit(df, spec.name, controls=controls)
        except Exception as e:
            print(f"  failed: {e}", flush=True)
            rows.append(_failed_row(spec.name, spec.label, e))
            continue
        rows.append(summarize_result(res, spec.name, spec.label))
    return rows


def _fit_specs(df: pd.DataFrame, specs: list[tuple], skip_failures: bool) -> list[dict]:
    """Fit the 2SLS model for each (outcome, label, controls) triple that has data."""
    rows = []
    for y, label, ctrls in specs:
        # build_dataset() may create an outcome column that is entirely NaN, so a
        # plain "column exists" check is not enough to know the outcome is usable.
        if not _has_data(df, y):
            print(f"Skipping {y}: column missing or empty", flush=True)
            continue
        try:
            res = fit_iv_interaction_absorb(df, y, controls=ctrls)
        except Exception as e:
            if not skip_failures:
                raise
            print(f"Skipping {y}: {e}", flush=True)
            continue
        rows.append(summarize_result(res, y, label))
    return rows


def _write_tables(table: pd.DataFrame, out_dir: Path, stem: str, tex_path: Path | None = None) -> None:
    """Write *table* as ``<stem>.csv`` and ``<stem>.md`` in *out_dir*, plus optional LaTeX."""
    (out_dir / f"{stem}.csv").write_text(table.to_csv(index=False), encoding="utf-8")
    (out_dir / f"{stem}.md").write_text(to_markdown_table(table) + "\n", encoding="utf-8")
    if tex_path is not None:
        tex_path.write_text(to_latex_tabular(table) + "\n", encoding="utf-8")


def main() -> None:
    """Run all estimations and write the result tables and a run summary.

    Builds the dataset, fits the main outcomes by 2SLS and LIML, then the
    mechanism, attitude and balance/placebo specifications, and writes CSV and
    Markdown tables to ``outputs`` and LaTeX tables to ``paper_joc/tables``.
    Failures in the main and LIML tables are kept as placeholder rows; outcomes
    without data are skipped with a message; attitude and balance fits that
    raise are skipped with a message; a mechanism fit that raises aborts the run.
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)

    df = build_dataset(root)

    outcomes = [
        OutcomeSpec("y_vote", "Vote (yes/no)", "binary"),
        OutcomeSpec("y_part_index", "Participation index (mean of non-vote items)", "continuous"),
        OutcomeSpec("y_sgnptit", "Signed petition", "binary"),
        OutcomeSpec("y_contplt", "Contacted politician/official", "binary"),
        OutcomeSpec("y_badge", "Worn campaign badge/sticker", "binary"),
        OutcomeSpec("y_pbldmn", "Taken part in public demonstration", "binary"),
        OutcomeSpec("y_bctprd", "Boycotted products", "binary"),
    ]
    controls = ["agea", "gndr", "hinctnta"]

    tab = _rows_to_frame(_fit_outcome_specs(df, outcomes, controls, fit_iv_interaction_absorb, "Fitting"))
    _write_tables(tab, out_dir, "iv_shiftshare_absorb_main_table")

    # LIML robustness table (same spec)
    liml_tab = _rows_to_frame(
        _fit_outcome_specs(df, outcomes, controls, fit_liml_interaction_absorb, "Fitting LIML")
    )
    _write_tables(liml_tab, out_dir, "iv_shiftshare_absorb_liml_main_table")

    paper_tables = root / "paper_joc" / "tables"
    paper_tables.mkdir(parents=True, exist_ok=True)

    # Mechanism proxies for the main manuscript (compact)
    mechanisms_main = [
        ("log_nwspol", "mechanism: log(1+news minutes)", ["agea", "gndr", "hinctnta"]),
        ("y_pstplonl", "mechanism: posted/shared politics online (1/0)", ["agea", "gndr", "hinctnta"]),
        ("polintr", "attitude: political interest", ["agea", "gndr", "hinctnta"]),
    ]
    mech = _rows_to_frame(_fit_specs(df, mechanisms_main, skip_failures=False))
    _write_tables(mech, out_dir, "iv_shiftshare_absorb_mechanisms", paper_tables / "iv_shiftshare_mechanisms.tex")

    # Extended attitude outcomes for supplement
    attitudes_specs = [
        ("actrolga", "attitude: able to take active role in political group", ["agea", "gndr", "hinctnta"]),
        ("prtdgcl", "attitude: party closeness", ["agea", "gndr", "hinctnta"]),
        ("stfdem", "attitude: satisfaction with democracy", ["agea", "gndr", "hinctnta"]),
        ("trstplt", "attitude: trust in politicians", ["agea", "gndr", "hinctnta"]),
        ("trstprl", "attitude: trust in parliament", ["agea", "gndr", "hinctnta"]),
        ("trstprt", "attitude: trust in parties", ["agea", "gndr", "hinctnta"]),
        ("ppltrst", "attitude: general interpersonal trust", ["agea", "gndr", "hinctnta"]),
        ("inprdsc", "social capital: network size for intimate discussion", ["agea", "gndr", "hinctnta"]),
    ]
    attitudes = _rows_to_frame(_fit_specs(df, attitudes_specs, skip_failures=True))
    _write_tables(attitudes, out_dir, "iv_shiftshare_absorb_attitudes", paper_tables / "iv_shiftshare_attitudes.tex")

    # Balance / placebo checks for supplement
    balance_specs = [
        ("brncntr_bin", "placebo: born in country (1/0)", ["agea", "gndr", "hinctnta"]),
        ("gndr", "placebo: gender (numeric code)", ["agea", "hinctnta"]),
        ("agea", "balance: age", ["gndr", "hinctnta"]),
        ("hinctnta", "balance: income decile", ["agea", "gndr"]),
        ("domicil", "balance: domicile (urban/rural)", ["agea", "gndr", "hinctnta"]),
    ]
    balance = _rows_to_frame(_fit_specs(df, balance_specs, skip_failures=True))
    _write_tables(balance, out_dir, "iv_shiftshare_absorb_balance", paper_tables / "iv_shiftshare_balance.tex")

    txt = [
        "Shift-share rollout-IV with absorbed region FE + country×year FE",
        "Instrument: region baseline broadband access (2016) × country-year high-speed coverage shocks (Eurostat isoc_cbs)",
        f"Rows in estimation pool: {len(df):,}",
        f"Countries: {df['cntry'].nunique()}  Regions: {df['region'].nunique()}  Years: {sorted(df['survey_year'].unique().tolist())}",
        "Outputs:",
        f"- {out_dir / 'iv_shiftshare_absorb_main_table.md'}",
        f"- {out_dir / 'iv_shiftshare_absorb_liml_main_table.md'}",
        f"- {out_dir / 'iv_shiftshare_absorb_mechanisms.md'}",
        f"- {out_dir / 'iv_shiftshare_absorb_attitudes.md'}",
        f"- {out_dir / 'iv_shiftshare_absorb_balance.md'}",
    ]
    (out_dir / "iv_shiftshare_absorb_checks.txt").write_text("\n".join(txt) + "\n", encoding="utf-8")
    print(f"Wrote: {out_dir / 'iv_shiftshare_absorb_main_table.md'}")

# Run the full estimation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
